#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Implements an iterative graph-filtering pipeline that progressively refines
a directed Wikipedia biographies network by enforcing minimum out-degree,
minimum population per century–field pair, and minimum field-of-activity
richness per century. Each pass produces diagnostic visualizations and
intermediate outputs, converging when the node set stabilizes.
'''

import os
import pickle
from typing import Optional

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import seaborn as sns

from connecting_people.graph.directed_graph import create_directed_graph


# filter 1: outdegree-based node filtering and graph to dataframe conversion

def filter1_nodes_by_outdegree(G: nx.DiGraph, min_outdegree: int) -> nx.DiGraph:
    """return a standalone copy of the subgraph induced by well-connected nodes.

    a node survives when its out-degree in G is greater than or equal to
    min_outdegree. G itself is not modified.
    """
    # walk the out-degree view once; it yields (node, out_degree) pairs
    survivors = [
        node for node, degree in G.out_degree() if degree >= min_outdegree
    ]

    # copy() detaches the result from G so later edits cannot leak back
    return G.subgraph(survivors).copy()


def graph_to_dataframe(G: nx.DiGraph) -> pd.DataFrame:
    """convert a directed graph back into an edge dataframe, one row per edge.

    the person-level values (YOB_S, COB_S, field_of_human_activity, YOB_T,
    COB_T, field_of_human_activity_T) are read from each edge's attribute
    dict, so every edge must carry all six keys; a missing key raises KeyError.

    the column list is passed explicitly so that a graph without edges still
    yields a frame with the expected (empty) columns rather than a column-less
    frame on which later column selections would fail.
    """
    columns = [
        "Source",
        "YOB_S",
        "COB_S",
        "field_of_human_activity",
        "Target",
        "YOB_T",
        "COB_T",
        "field_of_human_activity_T",
    ]
    # everything except the two endpoint columns comes from the edge data
    edge_keys = [name for name in columns if name not in ("Source", "Target")]

    rows = [
        {
            "Source": source,
            "Target": target,
            **{key: data[key] for key in edge_keys},
        }
        for source, target, data in G.edges(data=True)
    ]
    return pd.DataFrame(rows, columns=columns)


# filter 2: minimum people per (century, field_of_human_activity) + heatmap

def compute_people_per_cluster_century(df: pd.DataFrame) -> pd.DataFrame:
    """build a deduplicated person table with birth century, excluding the 21st century.

    sources and targets of the edge dataframe are stacked into one table with
    the columns person, YOB, field_of_human_activity and century. rows with a
    missing or zero YOB are dropped, as are people born in the 21st century.

    both halves are given the same column names before stacking. if the target
    half kept its "_T" field column, every target row would end up with a
    missing field_of_human_activity and silently vanish from any group-by on
    that column.
    """
    person_columns = ["person", "YOB", "field_of_human_activity"]

    df_source = df[["Source", "YOB_S", "field_of_human_activity"]].copy()
    df_source.columns = person_columns

    df_target = df[["Target", "YOB_T", "field_of_human_activity_T"]].copy()
    df_target.columns = person_columns

    df_combined = pd.concat(
        [df_source, df_target], ignore_index=True
    ).drop_duplicates()

    # a birth year is required to place someone in a century; 0 is not a valid year
    has_birth_year = df_combined["YOB"].notnull() & (df_combined["YOB"] != 0)
    # explicit copy so the new column is added to an independent frame
    df_combined = df_combined[has_birth_year].copy()

    # compute century: bce (yob < 0) -> yob // 100, ce (yob > 0) -> yob // 100 + 1
    df_combined["century"] = df_combined["YOB"].apply(
        lambda yob: (yob // 100) if yob < 0 else (yob // 100 + 1)
    ).astype(int)

    # explicitly exclude 21st century
    return df_combined[df_combined["century"] != 21]


def plot_cluster_heatmap(df_people: pd.DataFrame, min_people: int = 5) -> None:
    """plot a pass/fail heatmap of unique people per field_of_human_activity per century.

    cells holding at least min_people unique people are drawn green, all other
    cells gray. the figure is written to figures/filters/filter2.png, shown,
    and then closed so repeated calls do not accumulate open figures.

    df_people needs the columns person, century and field_of_human_activity.
    when there is no (century, field) cell to draw, the function returns
    without producing a figure.
    """
    counts = (
        df_people.groupby(["century", "field_of_human_activity"])["person"]
        .nunique()
        .reset_index()
    )
    counts = counts.rename(columns={"person": "num_people"})

    if counts.empty:
        # no cell to colour; the max() and heatmap calls below need data
        return

    heatmap_data = counts.pivot(
        index="century",
        columns="field_of_human_activity",
        values="num_people",
    )

    all_centuries = sorted(df_people["century"].unique())
    heatmap_data = heatmap_data.reindex(all_centuries).fillna(0)

    # create custom labels: 0 -> "1st CE", negative -> "X BCE", positive -> "X CE"
    labels = [
        f"{abs(c)} {'BCE' if c < 0 else 'CE'}" if c != 0 else "1st CE"
        for c in all_centuries
    ]

    cmap = mcolors.ListedColormap(["gray", "green"])
    # the bin edges must stay increasing even when no cell reaches min_people,
    # so the upper edge is never allowed to fall below min_people + 1
    upper_bound = max(heatmap_data.values.max(), min_people) + 1
    bounds = [0, min_people, upper_bound]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    fig, ax = plt.subplots(figsize=(10, 8))

    sns.heatmap(
        heatmap_data,
        cmap=cmap,
        norm=norm,
        cbar=False,
        linewidths=0.5,
        linecolor="black",
        ax=ax,
    )

    ax.set_title(
        "Unique People per Field of Human Activity per Century (Before Filter 2)",
        fontsize=14,
    )
    ax.set_xlabel("field_of_human_activity")
    ax.set_ylabel("Century")
    # heatmap row i spans [i, i + 1]; put each label at the row centre
    ax.set_yticks([i + 0.5 for i in range(len(all_centuries))])
    ax.set_yticklabels(labels)
    plt.yticks(rotation=0)
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    os.makedirs("figures/filters", exist_ok=True)
    plt.savefig("figures/filters/filter2.png", dpi=500)
    plt.show()
    # release the figure; this function is called once per pipeline iteration
    plt.close(fig)


def run_pipeline_filter2(df: pd.DataFrame, min_people: int = 5) -> pd.DataFrame:
    """run filter 2: keep edges whose endpoints both sit in a well-populated cell.

    a cell is a (century, field_of_human_activity) pair. a person is valid when
    at least one of their cells contains min_people or more unique people. an
    edge is kept only when both its Source and its Target are valid.

    the diagnostic heatmap is drawn from the unfiltered counts before the
    filter is applied. returns a filtered copy; df itself is not modified.
    """
    cell_keys = ["century", "field_of_human_activity"]
    df_people = compute_people_per_cluster_century(df)

    # plot heatmap before applying the filter
    plot_cluster_heatmap(df_people, min_people=min_people)

    # unique people per cell, indexed by (century, field)
    people_per_cell = df_people.groupby(cell_keys)["person"].nunique()

    # identify (century, field) pairs that pass the threshold
    passed_pairs = set(people_per_cell[people_per_cell >= min_people].index)

    # set membership over zipped columns avoids a python-level row-wise apply
    in_passing_cell = pd.Series(
        [
            pair in passed_pairs
            for pair in zip(
                df_people["century"], df_people["field_of_human_activity"]
            )
        ],
        index=df_people.index,
        dtype=bool,
    )
    valid_people = set(df_people.loc[in_passing_cell, "person"])

    both_endpoints_valid = df["Source"].isin(valid_people) & df["Target"].isin(
        valid_people
    )
    return df[both_endpoints_valid].copy()


# filter 3: minimum cluster richness per century + bar plot

def run_pipeline_filter3(
    df: pd.DataFrame,
    min_clusters: int = 3,
    *,
    export_dir: Optional[str] = "D:/Users/Paschalis/phd/data/connecting_people2",
) -> pd.DataFrame:
    """
    filter 3: keep only edges whose two endpoints both belong to a century in
    which the number of unique fields of human activity (source + target) is
    at least min_clusters. the 21st century never qualifies.

    a bar plot of the per-century richness is saved to
    figures/filters/filter3.png and shown before the filter is applied.

    the input dataframe is left untouched; all century values are computed on
    temporary frames. YOB_S and YOB_T must not contain missing values, because
    the century is cast to int.

    parameters:
        df: edge dataframe with Source, YOB_S, field_of_human_activity,
            Target, YOB_T and field_of_human_activity_T columns.
        min_clusters: minimum number of unique fields a century must have.
        export_dir: folder receiving connecting_people_filtered2.csv and
            connecting_people_filtered2.parquet; pass None to skip the export.

    returns:
        a filtered copy of df without century_S / century_T columns.
    """

    def to_century(years: pd.Series) -> pd.Series:
        """compute century for both bce (yob // 100) and ce (yob // 100 + 1) years."""
        return years.apply(
            lambda yob: (yob // 100) if yob < 0 else (yob // 100 + 1)
        ).astype(int)

    def endpoint_frame(person_col: str, yob_col: str, field_col: str) -> pd.DataFrame:
        """build a (person, century, field) table for one side of the edges."""
        return pd.DataFrame(
            {
                "person": df[person_col].to_numpy(),
                "century": to_century(df[yob_col]).to_numpy(),
                "field_of_human_activity": df[field_col].to_numpy(),
            }
        )

    def century_label(c: int) -> str:
        """format century as an ordinal label with bce/ce."""
        n = abs(int(c))
        suf = "th"
        if n % 100 not in (11, 12, 13):
            suf = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{n}{suf} {'BCE' if c < 0 else 'CE'}"

    def plot_cluster_richness(
        cluster_counts_local: pd.DataFrame,
        min_clusters_local: int,
        save_path: Optional[str] = None,
    ) -> None:
        """plot number of fields of human activity per century, marking those above threshold."""
        # exclude century 0 and plot by rank so gaps between centuries vanish
        cc = cluster_counts_local[cluster_counts_local["century"] != 0]
        cc = cc.sort_values("century")
        ordered = cc["century"].tolist()
        values = cc["num_field_of_human_activity"].to_numpy()

        fig, ax = plt.subplots(figsize=(14, 6))
        bars = ax.bar(range(len(ordered)), values, color="gray")

        # color bars above threshold
        for bar, val in zip(bars, values):
            if val >= min_clusters_local:
                bar.set_color("green")

        ax.axhline(
            min_clusters_local,
            color="red",
            linestyle="--",
            linewidth=1.5,
            label=f"Threshold = {min_clusters_local}",
        )
        ax.set_xlabel("Century", fontsize=12)
        ax.set_ylabel(
            "Number of Fields of Human Activities (richness S)", fontsize=12
        )
        ax.set_xticks(range(len(ordered)))
        ax.set_xticklabels(
            [century_label(c) for c in ordered],
            rotation=45,
        )
        ax.set_title(
            "Century-wide Field of Human Activities Richness (Before Filter 3)",
            fontsize=14,
        )
        ax.legend()
        ax.grid(axis="y", linestyle="--", alpha=0.7)

        if save_path:
            # create the folder the figure is actually written to
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            plt.savefig(save_path, bbox_inches="tight", dpi=500)
        plt.show()
        # release the figure; this function runs once per pipeline iteration
        plt.close(fig)

    # one row per distinct (person, century, field) seen on either edge side
    df_people = pd.concat(
        [
            endpoint_frame("Source", "YOB_S", "field_of_human_activity"),
            endpoint_frame("Target", "YOB_T", "field_of_human_activity_T"),
        ],
        ignore_index=True,
    ).drop_duplicates()

    cluster_counts = (
        df_people.groupby("century")["field_of_human_activity"]
        .nunique()
        .reset_index()
        .rename(
            columns={"field_of_human_activity": "num_field_of_human_activity"}
        )
    )

    plot_cluster_richness(
        cluster_counts,
        min_clusters_local=min_clusters,
        save_path="figures/filters/filter3.png",
    )

    rich_enough = cluster_counts["num_field_of_human_activity"] >= min_clusters
    # the 21st century is excluded regardless of its richness
    valid_centuries = set(cluster_counts.loc[rich_enough, "century"]) - {21}

    valid_people = set(
        df_people.loc[df_people["century"].isin(valid_centuries), "person"]
    )

    both_endpoints_valid = df["Source"].isin(valid_people) & df["Target"].isin(
        valid_people
    )
    df_after_filter3 = df[both_endpoints_valid].copy()

    # the result never carries helper century columns, even if the input did
    df_after_filter3 = df_after_filter3.drop(
        columns=["century_S", "century_T"], errors="ignore"
    )

    # export the final filtered edges to disk for reuse
    if export_dir is not None:
        os.makedirs(export_dir, exist_ok=True)
        export_base = os.path.join(export_dir, "connecting_people_filtered2")
        df_after_filter3.to_csv(
            export_base + ".csv",
            index=False,
            encoding="utf-8",
        )
        df_after_filter3.to_parquet(
            export_base + ".parquet",
            engine="pyarrow",
            compression="gzip",
        )

    return df_after_filter3


# iterative pipeline configuration and driver

# root folder shared by the input graph and the experiment outputs
_DATA_ROOT = "D:/Users/Paschalis/phd/data/connecting_people2"

experiment_config = {
    # filter 1: minimum out-degree a node needs to be kept
    "min_outdegree": 3,
    # filter 2: minimum unique people per (century, field) cell
    "min_people": 5,
    # filter 3: minimum unique fields of human activity per century
    "min_clusters": 5,
    # upper bound on the number of filtering passes
    "max_iterations": 100,
    # pickled starting graph, read when no start dataframe is supplied
    "graph_path": f"{_DATA_ROOT}/graph_before_filters.gpickle",
    # folder receiving per-iteration csv files and the convergence log
    "save_folder": f"{_DATA_ROOT}/experiments_filters/exp4",
}


def run_full_iterative_pipeline_param(
    config: dict,
    df_start: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
    """
    run the full iterative filtering pipeline using a configuration dictionary.

    first iteration:
        - if df_start is provided, build the graph from df_start
        - otherwise, load the initial graph from disk (pickle)
    subsequent iterations:
        - build a directed graph from the last filtered dataframe

    at each iteration, the function:
        1. applies filter 1 (outdegree threshold)
        2. applies filter 2 (minimum people per century–field pair)
        3. applies filter 3 (minimum cluster richness per century)

    the loop stops when the set of nodes is the same as after the previous
    iteration, when no nodes are left, or after max_iterations passes.

    intermediate results and convergence information are written to the folder
    specified in the config dictionary.

    config keys read: min_outdegree, min_people, min_clusters, max_iterations,
    graph_path, save_folder.

    returns the edge dataframe produced by the last iteration that ran.

    raises ValueError if max_iterations is smaller than 1, since no iteration
    would run and there would be no result to return.
    """
    # unpack configuration
    min_outdegree = config["min_outdegree"]
    min_people = config["min_people"]
    min_clusters = config["min_clusters"]
    max_iterations = config["max_iterations"]
    graph_path = config["graph_path"]
    save_folder = config["save_folder"]

    if max_iterations < 1:
        raise ValueError("max_iterations must be at least 1")

    # ensure save folder exists
    os.makedirs(save_folder, exist_ok=True)

    convergence_log = []  # records iteration, node count, edge count

    prev_nodes = None
    df_current: Optional[pd.DataFrame] = df_start

    for iteration in range(1, max_iterations + 1):
        # step 1: build or load the graph
        if df_current is None:
            # only possible on the first pass without df_start: load from disk.
            # NOTE: pickle.load can execute arbitrary code while loading, so
            # graph_path must point to a file from a trusted source.
            with open(graph_path, "rb") as f:
                G = pickle.load(f)
        else:
            # df_start was provided or df_current comes from the previous pass
            G = create_directed_graph(df_current)

        # filter 1: outdegree-based filtering
        G_filtered = filter1_nodes_by_outdegree(G, min_outdegree)
        df_after_filter1 = graph_to_dataframe(G_filtered)
        df_after_filter1.to_csv(
            os.path.join(save_folder, f"df_after_filter1_iter{iteration}.csv"),
            index=False,
        )

        # filter 2
        df_after_filter2 = run_pipeline_filter2(
            df_after_filter1, min_people=min_people
        )
        df_after_filter2.to_csv(
            os.path.join(save_folder, f"df_after_filter2_iter{iteration}.csv"),
            index=False,
        )

        # filter 3
        df_after_filter3 = run_pipeline_filter3(
            df_after_filter2, min_clusters=min_clusters
        )
        df_after_filter3.to_csv(
            os.path.join(save_folder, f"df_after_filter3_iter{iteration}.csv"),
            index=False,
        )

        # collect the unique nodes still present after this pass
        current_nodes = set(
            pd.concat(
                [df_after_filter3["Source"], df_after_filter3["Target"]],
                ignore_index=True,
            ).unique()
        )
        current_num_nodes = len(current_nodes)
        current_num_edges = len(df_after_filter3)

        convergence_log.append(
            {
                "iteration": iteration,
                "num_nodes": current_num_nodes,
                "num_edges": current_num_edges,
            }
        )

        # the result of this pass is always the candidate final result
        df_current = df_after_filter3.copy()

        # converged: the node set did not change since the previous pass
        if current_nodes == prev_nodes:
            break
        # nothing left to filter; another pass has no graph to work on
        if not current_nodes:
            break

        prev_nodes = current_nodes

    print(f"\nFinal number of unique nodes: {current_num_nodes}")
    print(f"Iterations run: {iteration}")

    # save final result dataframe
    df_final_csv = os.path.join(save_folder, "df_final2.csv")
    df_current.to_csv(df_final_csv, index=False)
    print(f"Final result saved to: {df_final_csv}")

    # save convergence log
    convergence_csv = os.path.join(save_folder, "convergence_log2.csv")
    pd.DataFrame(convergence_log).to_csv(convergence_csv, index=False)
    print(f"Convergence log saved to: {convergence_csv}")

    return df_current